{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ddfa9ae6-69fe-444a-b994-8c4c5970a7ec",
   "metadata": {},
   "source": [
    "# Week 2 Exercise - with Booking, Translation and Speech-To-Text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ccbf174-a724-46a8-9db4-addd249923a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note: The speech-to-text functionality requires FFmpeg to be installed. Go to the FFmpeg website and download the installer for your OS.\n",
    "# !pip install openai-whisper sounddevice scipy numpy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b50bbe2-c0b1-49c3-9a5c-1ba7efa2bcb4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# imports\n",
    "\n",
    "import os\n",
    "import json\n",
    "from dotenv import load_dotenv\n",
    "from openai import OpenAI\n",
    "import gradio as gr\n",
    "from anthropic import Anthropic\n",
    "import numpy as np\n",
    "import sounddevice as sd\n",
    "import scipy.io.wavfile as wav\n",
    "import tempfile\n",
    "import whisper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "747e8786-9da8-4342-b6c9-f5f69c2e22ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialization: load API keys from .env and create the API clients\n",
    "load_dotenv(override=True)\n",
    "openai_api_key = os.getenv('OPENAI_API_KEY')\n",
    "anthropic_api_key = os.getenv('ANTHROPIC_API_KEY')\n",
    "\n",
    "# Report missing keys up front rather than failing on the first API call\n",
    "for key_name, key_value in [(\"OPENAI_API_KEY\", openai_api_key), (\"ANTHROPIC_API_KEY\", anthropic_api_key)]:\n",
    "    if not key_value:\n",
    "        print(f\"Warning: {key_name} is not set - features using it will fail\")\n",
    "\n",
    "MODEL = \"gpt-4o-mini\"\n",
    "STT_DURATION = 3  # seconds of microphone audio recorded per 'Speak Command' press\n",
    "\n",
    "# Initialize clients; both receive the keys loaded above, for consistency\n",
    "openai = OpenAI(api_key=openai_api_key)\n",
    "anthropic = Anthropic(api_key=anthropic_api_key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a521d84-d07c-49ab-a0df-d6451499ed97",
   "metadata": {},
   "outputs": [],
   "source": [
    "system_message = \"You are a helpful assistant for an Airline called FlightAI. \"\n",
    "system_message += \"Give short, courteous answers, no more than 1 sentence. \"\n",
    "system_message += \"Always be accurate. If you don't know the answer, say so.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0696acb1-0b05-4dc2-80d5-771be04f1fb2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# get ticket price function\n",
    "\n",
    "ticket_prices = {\"london\": \"$799\", \"paris\": \"$899\", \"tokyo\": \"$1400\", \"berlin\": \"$499\", \"rome\": \"$699\", \"bucharest\": \"$949\", \"moscow\": \"$1199\"}\n",
    "\n",
    "def get_ticket_price(destination_city):\n",
    "    print(f\"Tool get_ticket_price called for {destination_city}\")\n",
    "    city = destination_city.lower()\n",
    "    return ticket_prices.get(city, \"Unknown\")\n",
    "\n",
    "# create booking function\n",
    "import random\n",
    "\n",
    "def create_booking(destination_city):\n",
    "    # Generate a random 6-digit number\n",
    "    digits = ''.join([str(random.randint(0, 9)) for _ in range(6)])  \n",
    "    booking_number = f\"AI{digits}\"\n",
    "    \n",
    "    # Print the booking confirmation message\n",
    "    print(f\"Booking {booking_number} created for the flight to {destination_city}\")\n",
    "    \n",
    "    return booking_number"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4afceded-7178-4c05-8fa6-9f2085e6a344",
   "metadata": {},
   "outputs": [],
   "source": [
    "# price function structure:\n",
    "\n",
    "# Tool schema for get_ticket_price in the OpenAI function-calling format:\n",
    "# 'description' tells the model when to call it, and 'parameters' is a JSON\n",
    "# Schema for the arguments the model must supply.\n",
    "price_function = {\n",
    "    \"name\": \"get_ticket_price\",\n",
    "    \"description\": \"Get the price of a return ticket to the destination city. Call this whenever you need to know the ticket price, for example when a customer asks 'How much is a ticket to this city'\",\n",
    "    \"parameters\": {\n",
    "        \"type\": \"object\",\n",
    "        \"properties\": {\n",
    "            \"destination_city\": {\n",
    "                \"type\": \"string\",\n",
    "                \"description\": \"The city that the customer wants to travel to\",\n",
    "            },\n",
    "        },\n",
    "        \"required\": [\"destination_city\"],\n",
    "        \"additionalProperties\": False\n",
    "    }\n",
    "}\n",
    "\n",
    "# booking function structure:\n",
    "# NOTE: the model sees this tool as 'make_booking'; handle_tool_call maps that\n",
    "# name onto the Python function create_booking.\n",
    "booking_function = {\n",
    "    \"name\": \"make_booking\",\n",
    "    \"description\": \"Make a flight booking for the customer. Call this whenever a customer wants to book a flight to a destination.\",\n",
    "    \"parameters\": {\n",
    "        \"type\": \"object\",\n",
    "        \"properties\": {\n",
    "            \"destination_city\": {\n",
    "                \"type\": \"string\",\n",
    "                \"description\": \"The city that the customer wants to travel to\",\n",
    "            },\n",
    "        },\n",
    "        \"required\": [\"destination_city\"],\n",
    "        \"additionalProperties\": False\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdca8679-935f-4e7f-97e6-e71a4d4f228c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# List of tools, each wrapped in the {\"type\": \"function\", ...} envelope\n",
    "# that chat.completions.create(tools=...) expects\n",
    "\n",
    "tools = [\n",
    "    {\"type\": \"function\", \"function\": price_function},\n",
    "    {\"type\": \"function\", \"function\": booking_function}\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0992986-ea09-4912-a076-8e5603ee631f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function handle_tool_call:\n",
    "\n",
    "def handle_tool_call(message):\n",
    "    tool_call = message.tool_calls[0]\n",
    "    function_name = tool_call.function.name\n",
    "    arguments = json.loads(tool_call.function.arguments)\n",
    "    \n",
    "    if function_name == \"get_ticket_price\":\n",
    "        city = arguments.get('destination_city')\n",
    "        price = get_ticket_price(city)\n",
    "        response = {\n",
    "            \"role\": \"tool\",\n",
    "            \"content\": json.dumps({\"destination_city\": city,\"price\": price}),\n",
    "            \"tool_call_id\": tool_call.id\n",
    "        }\n",
    "        return response, city\n",
    "    elif function_name == \"make_booking\":\n",
    "        city = arguments.get('destination_city')\n",
    "        booking_number = create_booking(city)\n",
    "        response = {\n",
    "            \"role\": \"tool\",\n",
    "            \"content\": json.dumps({\"destination_city\": city, \"booking_number\": booking_number}),\n",
    "            \"tool_call_id\": tool_call.id\n",
    "        }\n",
    "        return response, city"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "773a9f11-557e-43c9-ad50-56cbec3a0f8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Image generation\n",
    "\n",
    "import base64\n",
    "from io import BytesIO\n",
    "from PIL import Image\n",
    "\n",
    "def artist(city, testing_mode=False):\n",
    "    if testing_mode:\n",
    "        print(f\"Image generation skipped for {city} - in testing mode\")\n",
    "        return None\n",
    "    \n",
    "    image_response = openai.images.generate(\n",
    "            model=\"dall-e-3\",\n",
    "            prompt=f\"An image representing a vacation in {city}, showing tourist spots and everything unique about {city}, in a realistic style\",\n",
    "            size=\"1024x1024\",\n",
    "            n=1,\n",
    "            response_format=\"b64_json\",\n",
    "        )\n",
    "    image_base64 = image_response.data[0].b64_json\n",
    "    image_data = base64.b64decode(image_base64)\n",
    "    return Image.open(BytesIO(image_data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d1519a8-98ed-4673-ade0-aaba6341f155",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Text to speech \n",
    "\n",
    "import base64\n",
    "from io import BytesIO\n",
    "from PIL import Image\n",
    "from IPython.display import Audio, display\n",
    "\n",
    "def talker(message, testing_mode=False):\n",
    "    \"\"\"Generate speech from text and return the path to the audio file for Gradio to play\"\"\"\n",
    "    if testing_mode:\n",
    "        print(f\"Text-to-speech skipped - in testing mode\")\n",
    "        return None\n",
    "    \n",
    "    try:\n",
    "        response = openai.audio.speech.create(\n",
    "            model=\"tts-1\",\n",
    "            voice=\"onyx\",\n",
    "            input=message)\n",
    "\n",
    "        # Save to a unique filename based on timestamp to avoid caching issues\n",
    "        import time\n",
    "        timestamp = int(time.time())\n",
    "        output_filename = f\"output_audio_{timestamp}.mp3\"\n",
    "        \n",
    "        with open(output_filename, \"wb\") as f:\n",
    "            f.write(response.content)\n",
    "        \n",
    "        print(f\"Audio saved to {output_filename}\")\n",
    "        return output_filename\n",
    "    except Exception as e:\n",
    "        print(f\"Error generating speech: {e}\")\n",
    "        return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68149e08-d2de-4790-914a-6def79ff5612",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Speech to text function\n",
    "\n",
    "_WHISPER_MODELS = {}  # cache: Whisper model name -> loaded model\n",
    "\n",
    "def _get_whisper_model(model_name):\n",
    "    \"\"\"Load a Whisper model once and reuse it, so repeated presses don't reload it.\"\"\"\n",
    "    if model_name not in _WHISPER_MODELS:\n",
    "        _WHISPER_MODELS[model_name] = whisper.load_model(model_name)\n",
    "    return _WHISPER_MODELS[model_name]\n",
    "\n",
    "def recorder_and_transcriber(duration=STT_DURATION, samplerate=16000, testing_mode=False, model_name=\"base\"):\n",
    "    \"\"\"Record audio from the default input device and transcribe it with Whisper.\n",
    "\n",
    "    Args:\n",
    "        duration: seconds to record.\n",
    "        samplerate: recording sample rate in Hz (mono, float32 samples).\n",
    "        testing_mode: if True, skip recording and return a fixed test string.\n",
    "        model_name: Whisper model size: \"tiny\", \"base\", \"small\", \"medium\" or \"large\".\n",
    "\n",
    "    Returns:\n",
    "        The transcribed text with surrounding whitespace stripped.\n",
    "    \"\"\"\n",
    "    if testing_mode:\n",
    "        print(\"Speech-to-text skipped - in testing mode\")\n",
    "        return \"This is a test speech input\"\n",
    "\n",
    "    print(f\"Recording for {duration} seconds...\")\n",
    "\n",
    "    # Record audio using sounddevice\n",
    "    recording = sd.rec(int(duration * samplerate), samplerate=samplerate, channels=1, dtype='float32')\n",
    "    sd.wait()  # Wait until recording is finished\n",
    "\n",
    "    # Reserve a temp filename, then write to it after the handle is closed\n",
    "    with tempfile.NamedTemporaryFile(suffix=\".wav\", delete=False) as temp_audio:\n",
    "        temp_filename = temp_audio.name\n",
    "\n",
    "    try:\n",
    "        wav.write(temp_filename, samplerate, recording)\n",
    "        result = _get_whisper_model(model_name).transcribe(temp_filename)\n",
    "    finally:\n",
    "        # Always remove the temporary WAV, even if writing or transcription fails\n",
    "        os.unlink(temp_filename)\n",
    "\n",
    "    return result[\"text\"].strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf1d5600-8df8-4cc2-8bf5-b0b33818b385",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import glob\n",
    "\n",
    "def cleanup_audio_files():\n",
    "    \"\"\"Delete all MP3 files in the current directory that match our output pattern\"\"\"\n",
    "    \n",
    "    # Get all mp3 files that match our naming pattern\n",
    "    mp3_files = glob.glob(\"output_audio_*.mp3\")\n",
    "    \n",
    "    # Delete each file\n",
    "    count = 0\n",
    "    for file in mp3_files:\n",
    "        try:\n",
    "            os.remove(file)\n",
    "            count += 1\n",
    "        except Exception as e:\n",
    "            print(f\"Error deleting {file}: {e}\")\n",
    "    \n",
    "    print(f\"Cleaned up {count} audio files\")\n",
    "    return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44a6f8e0-c111-4e40-a5ae-68dd0aa9f65d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Translation function\n",
    "\n",
    "def translate_text(text, target_language):\n",
    "    if not text or not target_language:\n",
    "        return \"\"\n",
    "        \n",
    "    # Map the language dropdown values to language names for Claude\n",
    "    language_map = {\n",
    "        \"French\": \"French\",\n",
    "        \"Spanish\": \"Spanish\",\n",
    "        \"German\": \"German\",\n",
    "        \"Italian\": \"Italian\",\n",
    "        \"Russian\": \"Russian\",\n",
    "        \"Romanian\": \"Romanian\"\n",
    "    }\n",
    "    \n",
    "    full_language_name = language_map.get(target_language, \"French\")\n",
    "    \n",
    "    try:\n",
    "        response = anthropic.messages.create(\n",
    "            model=\"claude-3-haiku-20240307\",\n",
    "            max_tokens=1024,\n",
    "            messages=[\n",
    "                {\n",
    "                    \"role\": \"user\",\n",
    "                    \"content\": f\"Translate the following text to {full_language_name}. Provide only the translation, no explanations: \\n\\n{text}\"\n",
    "                }\n",
    "            ]\n",
    "        )\n",
    "        return response.content[0].text\n",
    "    except Exception as e:\n",
    "        print(f\"Translation error: {e}\")\n",
    "        return f\"[Translation failed: {str(e)}]\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba820c95-02f5-499e-8f3c-8727ee0a6c0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def chat(history, image, testing_mode=False):\n",
    "    \"\"\"Run one assistant turn over the chat history, executing a tool call if requested.\n",
    "\n",
    "    Args:\n",
    "        history: Gradio 'messages'-format history (list of role/content dicts),\n",
    "            presumably ending with the user's latest message.\n",
    "        image: the image currently displayed (None if there is none yet).\n",
    "        testing_mode: if True, skip image generation.\n",
    "\n",
    "    Returns:\n",
    "        (history, image, reply): history with the assistant reply appended, the\n",
    "        (possibly newly generated) image, and the reply text for text-to-speech.\n",
    "    \"\"\"\n",
    "    messages = [{\"role\": \"system\", \"content\": system_message}] + history\n",
    "    # parallel_tool_calls=False: handle_tool_call answers only tool_calls[0], and the\n",
    "    # API rejects the follow-up request unless every tool call id gets a tool reply.\n",
    "    response = openai.chat.completions.create(model=MODEL, messages=messages, tools=tools, parallel_tool_calls=False)\n",
    "\n",
    "    if response.choices[0].finish_reason == \"tool_calls\":\n",
    "        assistant_message = response.choices[0].message\n",
    "        tool_response, city = handle_tool_call(assistant_message)\n",
    "        messages.append(assistant_message)\n",
    "        messages.append(tool_response)\n",
    "\n",
    "        # Only generate an image if not in testing mode and none is shown yet\n",
    "        if not testing_mode and image is None:\n",
    "            image = artist(city, testing_mode)\n",
    "\n",
    "        response = openai.chat.completions.create(model=MODEL, messages=messages)\n",
    "\n",
    "    reply = response.choices[0].message.content\n",
    "    history += [{\"role\": \"assistant\", \"content\": reply}]\n",
    "\n",
    "    # Return the reply directly - TTS is handled by a separate Gradio step\n",
    "    return history, image, reply"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3cc58f3-d0fc-47d1-b9cf-e5bf4c5edbdc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function to translate conversation history\n",
    "def translate_history(history, target_language):\n",
    "    if not history or not target_language:\n",
    "        return []\n",
    "    \n",
    "    translated_history = []\n",
    "    \n",
    "    for msg in history:\n",
    "        role = msg[\"role\"]\n",
    "        content = msg[\"content\"]\n",
    "        \n",
    "        translated_content = translate_text(content, target_language)\n",
    "        translated_history.append({\"role\": role, \"content\": translated_content})\n",
    "    \n",
    "    return translated_history"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f38d0d27-33bf-4992-a2e5-5dbed973cde7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the FlightAI Gradio app: chat, translated chat, destination image and spoken replies\n",
    "def update_gradio_interface():\n",
    "    \"\"\"Construct and return the Gradio Blocks UI (without launching it).\n",
    "\n",
    "    Each user turn, typed or recorded via the Speak button, runs:\n",
    "    append user message -> chat() -> text-to-speech -> re-translate history.\n",
    "    \"\"\"\n",
    "    with gr.Blocks() as ui:\n",
    "        # Holds the latest assistant reply text between the chat and TTS steps\n",
    "        reply_state = gr.State(None)\n",
    "\n",
    "        with gr.Row():\n",
    "            testing_checkbox = gr.Checkbox(label=\"Testing\", info=\"Turn off multimedia features when checked\", value=False)\n",
    "\n",
    "        with gr.Row():\n",
    "            with gr.Column(scale=2):\n",
    "                # Main panel with original chat and image\n",
    "                with gr.Row():\n",
    "                    with gr.Column(scale=1):\n",
    "                        with gr.Row():\n",
    "                            chatbot = gr.Chatbot(height=300, type=\"messages\")\n",
    "                        with gr.Row():\n",
    "                            language_dropdown = gr.Dropdown(\n",
    "                                choices=[\"French\", \"Spanish\", \"German\", \"Italian\", \"Russian\", \"Romanian\"],\n",
    "                                value=\"French\",\n",
    "                                label=\"Translation to\"\n",
    "                            )\n",
    "                        with gr.Row():\n",
    "                            translation_output = gr.Chatbot(height=200, type=\"messages\", label=\"Translated chat\")\n",
    "                    with gr.Column(scale=1):\n",
    "                        with gr.Row():\n",
    "                            image_output = gr.Image(height=620)\n",
    "                        with gr.Row():\n",
    "                            audio_output = gr.Audio(label=\"Assistant's Voice\", visible=False, autoplay=True, type=\"filepath\")\n",
    "\n",
    "        with gr.Row():\n",
    "            entry = gr.Textbox(label=\"Chat with our AI Assistant:\")\n",
    "\n",
    "        # Button row; the empty Markdown columns are spacers around the buttons\n",
    "        with gr.Row():\n",
    "            with gr.Column(scale=1):\n",
    "                with gr.Row():\n",
    "                    gr.Markdown()\n",
    "            with gr.Column(scale=1):\n",
    "                speak_button = gr.Button(value=\"🎤 Speak Command\", variant=\"primary\")\n",
    "            with gr.Column(scale=1):\n",
    "                with gr.Row():\n",
    "                    gr.Markdown()\n",
    "            with gr.Column(scale=1):\n",
    "                with gr.Row():\n",
    "                    clear = gr.Button(value=\"Clear\", variant=\"secondary\")\n",
    "            with gr.Column(scale=1):\n",
    "                with gr.Row():\n",
    "                    gr.Markdown()\n",
    "\n",
    "        def do_speech_input(testing_mode):\n",
    "            \"\"\"Record STT_DURATION seconds of speech and return the transcript for the textbox.\"\"\"\n",
    "            return recorder_and_transcriber(duration=STT_DURATION, testing_mode=testing_mode)\n",
    "\n",
    "        def do_entry(message, history, testing_mode):\n",
    "            \"\"\"Append the user's message to the chat history and clear the textbox.\"\"\"\n",
    "            history += [{\"role\": \"user\", \"content\": message}]\n",
    "            return \"\", history\n",
    "\n",
    "        def do_translation(history, language):\n",
    "            \"\"\"Re-translate the whole chat history into the selected language.\"\"\"\n",
    "            return translate_history(history, language)\n",
    "\n",
    "        def do_tts(reply, testing_mode):\n",
    "            \"\"\"Speak the latest reply; returns an audio file path, or None when skipped.\"\"\"\n",
    "            if not reply or testing_mode:\n",
    "                return None\n",
    "            return talker(reply, testing_mode)\n",
    "\n",
    "        def add_reply_pipeline(event):\n",
    "            \"\"\"Chain assistant reply -> speech -> translation after an event that appended a user message.\"\"\"\n",
    "            return event.then(\n",
    "                chat, inputs=[chatbot, image_output, testing_checkbox], outputs=[chatbot, image_output, reply_state]\n",
    "            ).then(\n",
    "                do_tts, inputs=[reply_state, testing_checkbox], outputs=[audio_output]\n",
    "            ).then(\n",
    "                do_translation, inputs=[chatbot, language_dropdown], outputs=[translation_output]\n",
    "            )\n",
    "\n",
    "        # Typed message: pressing Enter in the textbox\n",
    "        add_reply_pipeline(\n",
    "            entry.submit(do_entry, inputs=[entry, chatbot, testing_checkbox], outputs=[entry, chatbot])\n",
    "        )\n",
    "\n",
    "        # Spoken message: record and transcribe into the textbox, then treat it as typed\n",
    "        add_reply_pipeline(\n",
    "            speak_button.click(do_speech_input, inputs=[testing_checkbox], outputs=[entry]).then(\n",
    "                do_entry, inputs=[entry, chatbot, testing_checkbox], outputs=[entry, chatbot]\n",
    "            )\n",
    "        )\n",
    "\n",
    "        # Update translation when language is changed\n",
    "        language_dropdown.change(do_translation, inputs=[chatbot, language_dropdown], outputs=[translation_output])\n",
    "\n",
    "        def clear_all():\n",
    "            \"\"\"Delete generated audio files and reset chat, translation, image and audio.\"\"\"\n",
    "            cleanup_audio_files()\n",
    "            # Empty lists (not None) for the chatbots so the next turn starts from a valid history\n",
    "            return [], [], None, None\n",
    "\n",
    "        clear.click(clear_all, inputs=None, outputs=[chatbot, translation_output, image_output, audio_output], queue=False)\n",
    "\n",
    "    return ui\n",
    "\n",
    "# Build and launch the app in a new browser tab\n",
    "ui = update_gradio_interface()\n",
    "ui.launch(inbrowser=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
